
import datetime
import os
import random
import time
from collections import deque
from itertools import count
import types

import hydra
import pickle
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Categorical
import wandb
from omegaconf import DictConfig, OmegaConf
from tensorboardX import SummaryWriter

from make_envs import make_env
from memory import Memory
from utils import eval_mode, get_concat_samples, evaluate, soft_update, hard_update
from logger import Logger

# Cap the number of intra-op CPU threads torch uses in this process.
_TORCH_NUM_THREADS = 2
torch.set_num_threads(_TORCH_NUM_THREADS)


class BCNet(nn.Module):
    """Behaviour-cloning policy: an MLP mapping observations to action logits.

    Args:
        obs_dim: Size of the flat observation vector fed to the first layer.
        action_dim: Number of discrete actions (size of the logit output).
        hidden_dim: Width of each of the two hidden layers.
    """

    def __init__(self, obs_dim, action_dim, hidden_dim):
        super().__init__()
        self.obs_dim = obs_dim
        self.action_dim = action_dim
        # Kept for callers that read ``agent.device``. choose_action itself
        # follows the device the parameters actually live on, so moving the
        # module with ``.to(...)`` cannot desynchronise input and weights.
        self.device = "cpu"
        self.model = nn.Sequential(
            nn.Linear(obs_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, action_dim),
        )

    def forward(self, obs):
        """Return unnormalised action logits for a batch of observations."""
        return self.model(obs)

    def choose_action(self, state, sample=False):
        """Pick an action for a single observation by sampling softmax(logits).

        Args:
            state: One observation (array-like). Objects whose type is named
                ``LazyFrames`` are converted to an array and scaled by 1/255.
            sample: Accepted for interface compatibility but currently
                ignored; an action is always sampled (the greedy/argmax
                branch is intentionally disabled).

        Returns:
            The sampled action index as a NumPy scalar.
        """
        # ``LazyFrames`` is not imported in this module; the original
        # ``isinstance(state, LazyFrames)`` raised NameError on every call.
        # Matching on the type name keeps the intended conversion without
        # requiring the wrapper module.
        if type(state).__name__ == "LazyFrames":
            state = np.array(state) / 255.0
        device = next(self.parameters()).device
        state = torch.as_tensor(state, dtype=torch.float32, device=device).unsqueeze(0)
        with torch.no_grad():
            logits = self.forward(state)
            probs = F.softmax(logits, dim=1)
            action = Categorical(probs).sample()

        return action.detach().cpu().numpy()[0]

def get_args(cfg: DictConfig):
    """Fill runtime-derived fields on the Hydra config, echo it, and return it.

    Sets ``cfg.device`` to the first CUDA device when one is available
    (otherwise the CPU) and records the current working directory as
    ``cfg.hydra_base_dir``. The config is mutated in place and returned.
    """
    has_gpu = torch.cuda.is_available()
    if has_gpu:
        cfg.device = "cuda:0"
    else:
        cfg.device = "cpu"
    cfg.hydra_base_dir = os.getcwd()

    rendered = OmegaConf.to_yaml(cfg)
    print(rendered)
    return cfg


def _train_bc(agent, optimizer, obs, actions, epochs, writer=None, log_every=100):
    """Fit ``agent`` to expert (obs, action) pairs with full-batch cross-entropy.

    Args:
        agent: Module returning action logits for ``obs``.
        optimizer: Optimizer over ``agent``'s parameters.
        obs: Expert observations, already on the agent's device.
        actions: Expert actions; cast to ``long`` and flattened as class labels.
        epochs: Number of full-batch gradient steps.
        writer: Optional SummaryWriter that receives the loss every ``log_every`` epochs.
        log_every: Print/log cadence in epochs (the last epoch is always reported).

    Returns:
        The loss value of the final epoch (NaN if ``epochs`` is 0).
    """
    criterion = nn.CrossEntropyLoss()
    targets = actions.long().flatten()  # labels never change; build them once
    loss_value = float("nan")
    for epoch in range(epochs):
        loss = criterion(agent(obs), targets)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if epoch % log_every == 0 or epoch == epochs - 1:
            loss_value = loss.item()
            if writer is not None:
                writer.add_scalar("Loss/bc_loss", loss_value, global_step=epoch)
            print(f"epoch {epoch}: loss {loss_value:.6f}")
    return loss_value


def _save_results(path, returns, returns_std):
    """Pickle ``(returns, returns_std)`` to ``path``, creating parent dirs if needed."""
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)
    with open(path, "wb") as f:
        print("Saving Pickle")
        pickle.dump((returns, returns_std), f)


@hydra.main(config_path="conf", config_name="config")
def main(cfg: DictConfig):
    """Train a behaviour-cloning policy on expert demos, evaluate it, save returns.

    Reads the Hydra config, seeds all RNGs, loads expert transitions, fits
    ``BCNet`` with cross-entropy on the expert actions, evaluates it on a
    separate env, logs mean/std return to tensorboard, and pickles
    ``(mean_return, std_return)`` to the results directory.
    """
    args = get_args(cfg)
    # TODO: set your wandb entity. None falls back to wandb's default entity
    # for the logged-in user. (An unfinished ``entity=`` followed by a comment
    # is a SyntaxError.)
    wandb.init(project=args.project_name, entity=None,
               sync_tensorboard=True, reinit=True, config=args)

    # Set seeds.
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    device = torch.device(args.device)
    if device.type == 'cuda' and torch.cuda.is_available() and args.cuda_deterministic:
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True
    ts_str = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    log_dir = os.path.join(args.log_dir, args.env.name, args.exp_name,
                           args.method.type, str(args.seed), ts_str)
    writer = SummaryWriter(log_dir=log_dir)

    env = make_env(args)
    eval_env = make_env(args)
    env.seed(args.seed)
    eval_env.seed(args.seed + 10)

    replay_memory = int(args.env.replay_mem)

    # The expert tensors are loaded onto ``device`` below, so the policy must
    # live there too; otherwise training fails with a device mismatch on CUDA.
    agent_bc = BCNet(env.observation_space.shape[0], env.action_space.n, 128).to(device)
    agent_bc.device = device
    bc_optimizer = torch.optim.Adam(agent_bc.parameters(), lr=1e-4, betas=(0.9, 0.999))

    # Load expert data.
    expert_memory_replay = Memory(replay_memory // 2, args.seed)
    expert_memory_replay.load(hydra.utils.to_absolute_path(f'experts/{args.env.demo}'),
                              num_trajs=args.eval.demos,
                              sample_freq=args.eval.subsample_freq,
                              seed=args.seed + 42)
    print(f'--> Expert memory size: {expert_memory_replay.size()}')
    expert_obs, _, expert_action, _, _ = expert_memory_replay.get_samples(replay_memory // 2, device)

    bc_epochs = 10000
    _train_bc(agent_bc, bc_optimizer, expert_obs, expert_action, bc_epochs, writer=writer)

    eval_returns, _ = evaluate(agent_bc, eval_env, num_episodes=args.eval.eps)
    returns = np.mean(eval_returns)
    returns_std = np.std(eval_returns)
    writer.add_scalar('Rewards/eval_rewards', returns, global_step=1)
    writer.add_scalar('Rewards/std_rewards', returns_std, global_step=1)
    writer.close()
    print("Returns:", returns)

    # NOTE(review): both lr tags use ``args.agent.critic_lr``, and BC actually
    # trains with a fixed lr of 1e-4. This looks like a copy-paste slip; it is
    # kept unchanged so existing result filenames stay comparable.
    result_name = (f"{args.env.name}{args.seed}n_trajs{args.eval.demos}"
                   f"_lr_w{args.agent.critic_lr}_lr_theta{args.agent.critic_lr}.pt")
    # Relative path: presumably resolved from Hydra's per-run working directory,
    # three levels below the project root. Confirm against the Hydra run dir config.
    _save_results(os.path.join("../../../pickle_results/bc", result_name),
                  returns, returns_std)


if __name__ == "__main__":
    main()

